import threading
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional
from datetime import datetime
from loguru import logger

# Import from our consolidated types module
from sentientresearchagent.hierarchical_agent_framework.types import (
    TaskStatusLiteral, TaskTypeLiteral, NodeTypeLiteral,
    TaskStatus, TaskType, NodeType
)

class TaskRecord(BaseModel):
    """Represents the historical record of a task in the Knowledge Store.

    A flattened, serializable snapshot of a TaskNode. Instances are created
    by ``KnowledgeStore.add_or_update_record_from_node`` and keyed by
    ``task_id`` in ``KnowledgeStore.records``.
    """
    task_id: str  # unique id; primary key in KnowledgeStore.records
    goal: str  # the task's goal text
    task_type: TaskTypeLiteral  # e.g. PLAN / EXECUTE (stored as plain string)
    node_type: Optional[NodeTypeLiteral] = None  # may be absent on the source node
    
    input_params_dict: Dict[str, Any] = Field(default_factory=dict)  # node's input_payload_dict
    output_content: Optional[Any] = None  # raw output (mirrors node.result)
    output_type_description: Optional[str] = None # e.g., "text_report_section", "plan_output"
    output_summary: Optional[str] = None  # short human-readable summary of the output

    status: TaskStatusLiteral  # lifecycle status (stored as plain string)
    timestamp_created: datetime
    timestamp_updated: datetime
    timestamp_completed: Optional[datetime] = None  # set only once the task finishes
    
    parent_task_id: Optional[str] = None  # id of the node that spawned this task, if any
    # child_task_ids refers to tasks *generated by this task* if it was a PLAN task
    child_task_ids_generated: List[str] = Field(default_factory=list)
    layer: Optional[int] = None  # depth in the task hierarchy
    error_message: Optional[str] = None  # populated from node.error on failure
    
    # For PLAN nodes, to link to the graph of their sub-tasks
    sub_graph_id: Optional[str] = None
    
    # CRITICAL FIX: Store aux_data for dependency information and other metadata
    # (e.g. depends_on_indices used by dependency resolution).
    aux_data: Dict[str, Any] = Field(default_factory=dict)
    
    # Additional fields for dependency resolution
    result: Optional[Any] = None  # Store the actual result (duplicates output_content by design)
    planned_sub_task_ids: List[str] = Field(default_factory=list)  # For PLAN nodes

    class Config:
        use_enum_values = True # Keep this for serialization consistency

class KnowledgeStore(BaseModel):
    """A central repository for all task records.

    Thread-safe: every method that touches ``records`` acquires an internal
    re-entrant lock. The lock is kept out of pydantic's field machinery and
    attached via ``object.__setattr__`` because ``threading.RLock`` is not a
    serializable pydantic type.
    """
    records: Dict[str, TaskRecord] = Field(default_factory=dict)
    
    class Config:
        arbitrary_types_allowed = True  # Allow non-pydantic types
    
    def __init__(self, **data):
        super().__init__(**data)
        # Initialize the lock after the object is created so pydantic
        # validation never sees (or tries to serialize) the RLock.
        object.__setattr__(self, '_lock', threading.RLock())

    def _ensure_lock(self) -> threading.RLock:
        """Return this store's lock, recreating it if missing.

        Instances produced by deserialization/copying may bypass ``__init__``
        (NOTE(review): presumed — confirm against pydantic version in use),
        so every lock user goes through this guard instead of touching
        ``self._lock`` directly. Previously only one method had this guard,
        leaving the others open to AttributeError.
        """
        if not hasattr(self, '_lock') or self._lock is None:
            object.__setattr__(self, '_lock', threading.RLock())
        return self._lock

    def add_or_update_record_from_node(self, node: Any): # Use Any to avoid circular dep with TaskNode initially
        """Creates or updates a TaskRecord from a TaskNode with improved type handling."""
        with self._ensure_lock():
            # Simplified enum handling - string enums work directly, so str()
            # yields the plain value expected by the Literal-typed fields.
            task_type_val = str(node.task_type)
            node_type_val = str(node.node_type) if node.node_type else None
            status_val = str(node.status)

            record = TaskRecord(
                task_id=node.task_id,
                goal=node.goal,
                task_type=task_type_val,
                node_type=node_type_val,
                input_params_dict=node.input_payload_dict or {},
                output_content=node.result,
                output_type_description=node.output_type_description,
                output_summary=node.output_summary,
                status=status_val,
                timestamp_created=node.timestamp_created,
                timestamp_updated=node.timestamp_updated,
                timestamp_completed=node.timestamp_completed,
                parent_task_id=node.parent_node_id,
                child_task_ids_generated=node.planned_sub_task_ids or [],
                layer=node.layer,
                error_message=node.error,
                sub_graph_id=node.sub_graph_id,  # CRITICAL FIX: Include sub_graph_id for PLAN nodes
                aux_data=node.aux_data or {},  # CRITICAL FIX: Preserve aux_data including depends_on_indices
                result=node.result,  # Store actual result for dependency context
                planned_sub_task_ids=node.planned_sub_task_ids or []  # For dependency resolution
            )
            self.records[record.task_id] = record
            logger.info(f"KnowledgeStore: Added/Updated record for {node.task_id}")

    def get_record(self, task_id: str) -> Optional[TaskRecord]:
        """Get a task record by ID."""
        with self._ensure_lock():
            return self.records.get(task_id)

    def get_records_by_status(self, status: TaskStatusLiteral) -> List[TaskRecord]:
        """Get all records with a specific status."""
        with self._ensure_lock():
            return [record for record in self.records.values() if record.status == status]

    def get_records_by_layer(self, layer: int) -> List[TaskRecord]:
        """Get all records at a specific layer."""
        with self._ensure_lock():
            return [record for record in self.records.values() if record.layer == layer]

    def get_child_records(self, parent_task_id: str) -> List[TaskRecord]:
        """Get all direct child records of a parent task."""
        with self._ensure_lock():
            return [record for record in self.records.values() 
                    if record.parent_task_id == parent_task_id]

    def clear(self):
        """Clear all records."""
        with self._ensure_lock():
            self.records.clear()
            logger.info("KnowledgeStore: All records cleared")

    def get_summary_stats(self) -> Dict[str, Any]:
        """Get summary statistics about stored records.

        Returns:
            Dict with ``total_records`` and, when non-empty, a
            ``status_breakdown`` count per status plus the distinct
            non-None ``layers`` present.
        """
        with self._ensure_lock():
            if not self.records:
                return {"total_records": 0}

            statuses = [record.status for record in self.records.values()]
            status_counts = {status: statuses.count(status) for status in set(statuses)}

            return {
                "total_records": len(self.records),
                "status_breakdown": status_counts,
                "layers": list(set(record.layer for record in self.records.values() if record.layer is not None))
            }
    
    def get_record_by_task_id(self, task_id: str) -> Optional[TaskRecord]:
        """Get a task record by ID (alias for get_record for v2 compatibility)."""
        return self.get_record(task_id)
